import asyncio
import os
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Optional

from loguru import logger
from pydantic import BaseModel, Field

from swarms.structs.agent import Agent
from swarms.structs.base_swarm import BaseSwarm
from swarms.utils.file_processing import create_file_in_folder


class AgentOutputSchema(BaseModel):
    # Record of one agent's execution: identifiers, the task it received,
    # what it produced, and wall-clock timing.
    #
    # Every field uses default=... (required), yet is typed Optional, so a
    # value must be supplied explicitly but that value may be None.

    # Per-execution identifier.
    run_id: Optional[str] = Field(
        default=..., description="Unique ID for the run"
    )
    # Name of the agent that produced this record.
    agent_name: Optional[str] = Field(
        default=..., description="Name of the agent"
    )
    # The prompt/task the agent was given.
    task: Optional[str] = Field(
        default=..., description="Task or query given to the agent"
    )
    # The agent's textual result.
    output: Optional[str] = Field(
        default=..., description="Output generated by the agent"
    )
    # Timing information; duration is expressed in seconds.
    start_time: Optional[datetime] = Field(
        default=..., description="Start time of the task"
    )
    end_time: Optional[datetime] = Field(
        default=..., description="End time of the task"
    )
    duration: Optional[float] = Field(
        default=...,
        description="Duration taken to complete the task (in seconds)",
    )


class MetadataSchema(BaseModel):
    # Aggregate record for one workflow run: a run identifier, the shared
    # task, a description, every agent's output record, and a timestamp.
    #
    # swarm_id, task and agents are required (default=...) but nullable;
    # description and timestamp have defaults.

    swarm_id: Optional[str] = Field(
        default=..., description="Unique ID for the run"
    )
    task: Optional[str] = Field(
        default=..., description="Task or query given to all agents"
    )
    description: Optional[str] = Field(
        default="Concurrent execution of multiple agents",
        description="Description of the workflow",
    )
    agents: Optional[List[AgentOutputSchema]] = Field(
        default=..., description="List of agent outputs and metadata"
    )
    # The timestamp is taken when the model instance is created
    # (datetime.now is called per instance via default_factory).
    timestamp: Optional[datetime] = Field(
        description="Timestamp of the workflow execution",
        default_factory=datetime.now,
    )


class ConcurrentWorkflow(BaseSwarm):
    """
    Represents a concurrent workflow that executes multiple agents concurrently.

    Every agent is given the same task. Each blocking ``agent.run`` call is
    dispatched to a thread pool from an asyncio event loop, and the per-agent
    results are collected into a ``MetadataSchema``.

    Args:
        name (str): The name of the workflow. Defaults to "ConcurrentWorkflow".
        description (str): The description of the workflow. Defaults to
            "Execution of multiple agents concurrently". Must be non-empty.
        agents (List[Agent]): The agents to be executed concurrently.
            Required; a missing or empty list raises ValueError.
        metadata_output_path (str): File name for the saved metadata, created
            inside ``$WORKSPACE_DIR`` (or the current working directory when
            that variable is unset or empty). Defaults to "agent_metadata.json".
        auto_save (bool): Whether to save the metadata after each run.
            Defaults to True.
        output_schema (BaseModel): The output schema for the metadata.
            Defaults to MetadataSchema. After each ``run`` this attribute holds
            the populated MetadataSchema instance.
        max_loops (int): Stored on the instance; ``run`` in this class executes
            each agent exactly once regardless. Defaults to 1.
        return_str_on (bool): If True, ``run`` returns a plain-text summary of
            the agent responses instead of JSON. Defaults to False.
        agent_responses (list): Initial value of ``self.agent_responses``; it is
            replaced by ``transform_metadata_schema_to_str``. Defaults to a new
            empty list per instance.
        *args, **kwargs: Forwarded to ``BaseSwarm.__init__``.

    Raises:
        ValueError: If the list of agents is empty or if the description is empty.

    Attributes:
        name (str): The name of the workflow.
        description (str): The description of the workflow.
        agents (List[Agent]): The list of agents to be executed concurrently.
        metadata_output_path (str): The file name used when saving metadata.
        auto_save (bool): Flag indicating whether to automatically save the metadata.
        output_schema (BaseModel): The metadata schema class, or the last run's
            MetadataSchema instance once ``run`` has been called.
        max_loops (int): Stored configuration value.
        return_str_on (bool): Selects the plain-text return format of ``run``.
        agent_responses (list): Formatted responses from the last string transform.
    """

    def __init__(
        self,
        name: str = "ConcurrentWorkflow",
        description: str = "Execution of multiple agents concurrently",
        agents: Optional[List[Agent]] = None,
        metadata_output_path: str = "agent_metadata.json",
        auto_save: bool = True,
        output_schema: BaseModel = MetadataSchema,
        max_loops: int = 1,
        return_str_on: bool = False,
        agent_responses: Optional[list] = None,
        *args,
        **kwargs,
    ):
        # Validate first so an invalid configuration fails fast with the
        # documented ValueError, before any base-class setup runs.
        if not agents:
            raise ValueError("The list of agents cannot be empty.")

        if not description:
            raise ValueError("The description cannot be empty.")

        super().__init__(name=name, agents=agents, *args, **kwargs)
        self.name = name
        self.description = description
        self.agents = agents
        self.metadata_output_path = metadata_output_path
        self.auto_save = auto_save
        self.output_schema = output_schema
        self.max_loops = max_loops
        self.return_str_on = return_str_on
        # A fresh list per instance. A literal [] default would be shared by
        # every workflow created without this argument.
        self.agent_responses = (
            agent_responses if agent_responses is not None else []
        )

    async def _run_agent(
        self, agent: Agent, task: str, executor: ThreadPoolExecutor
    ) -> AgentOutputSchema:
        """
        Runs a single agent with the given task and tracks its output and metadata.

        Failures are best-effort. If ``agent.run`` raises, the error is logged
        and recorded as the agent's output (``"Error: <exception>"``), so the
        other agents' results are still returned.

        Args:
            agent (Agent): The agent instance to run.
            task (str): The task or query to give to the agent.
            executor (ThreadPoolExecutor): The thread pool executor to use for running the agent task.

        Returns:
            AgentOutputSchema: The metadata and output from the agent's execution.
        """
        loop = asyncio.get_running_loop()
        start_time = datetime.now()
        try:
            # agent.run is blocking, so run it on a worker thread to keep the
            # event loop free to drive the other agents.
            output = await loop.run_in_executor(
                executor, agent.run, task
            )
        except Exception as e:
            logger.error(
                f"Agent {agent.agent_name} failed on task: {task}. Error: {e}"
            )
            output = f"Error: {e}"

        # AgentOutputSchema.output is typed Optional[str]. Coerce any other
        # return type so schema validation does not abort the whole gather.
        if output is not None and not isinstance(output, str):
            output = str(output)

        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()

        agent_output = AgentOutputSchema(
            run_id=uuid.uuid4().hex,
            agent_name=agent.agent_name,
            task=task,
            output=output,
            start_time=start_time,
            end_time=end_time,
            duration=duration,
        )

        logger.info(
            f"Agent {agent.agent_name} completed task: {task} in {duration:.2f} seconds."
        )

        return agent_output

    def transform_metadata_schema_to_str(
        self, schema: MetadataSchema
    ):
        """
        Convert a metadata schema into a plain-text summary of agent responses.

        Each agent contributes a "Agent Name: ...\\nResponse: ...\\n\\n" entry.
        The entries are stored in ``self.agent_responses`` and joined with
        newlines.

        Args:
            schema (MetadataSchema): The metadata produced by a workflow run.

        Returns:
            str: The concatenated per-agent summaries.
        """
        self.agent_responses = [
            f"Agent Name: {agent.agent_name}\nResponse: {agent.output}\n\n"
            for agent in schema.agents
        ]

        return "\n".join(self.agent_responses)

    async def _execute_agents_concurrently(
        self, task: str
    ) -> MetadataSchema:
        """
        Executes multiple agents concurrently with the same task.

        Args:
            task (str): The task or query to give to all agents.

        Returns:
            MetadataSchema: The aggregated metadata and outputs from all agents,
            with agent entries in the same order as ``self.agents``.
        """
        # When os.cpu_count() returns None, ThreadPoolExecutor falls back to
        # its own default worker count.
        with ThreadPoolExecutor(
            max_workers=os.cpu_count()
        ) as executor:
            tasks_to_run = [
                self._run_agent(agent, task, executor)
                for agent in self.agents
            ]

            agent_outputs = await asyncio.gather(*tasks_to_run)
        return MetadataSchema(
            swarm_id=uuid.uuid4().hex,
            task=task,
            description=self.description,
            agents=agent_outputs,
        )

    @staticmethod
    def _run_coroutine_sync(coro):
        """
        Drive ``coro`` to completion from synchronous code and return its result.

        ``asyncio.run`` raises RuntimeError when an event loop is already
        running in the calling thread (for example inside Jupyter or an async
        web handler). In that case the coroutine runs on a fresh event loop in
        a single worker thread. Note that this blocks the caller's loop until
        the coroutine finishes.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            loop_running = False
        else:
            loop_running = True

        if not loop_running:
            return asyncio.run(coro)

        with ThreadPoolExecutor(max_workers=1) as runner:
            return runner.submit(asyncio.run, coro).result()

    def run(self, task: str) -> str:
        """
        Runs the workflow for the given task, executes agents concurrently, and saves metadata.

        Side effects: ``self.output_schema`` is replaced with the populated
        MetadataSchema instance. When ``auto_save`` is True, the metadata JSON
        is written to ``metadata_output_path`` inside ``$WORKSPACE_DIR``, or
        the current working directory when that variable is unset or empty.

        Args:
            task (str): The task or query to give to all agents.

        Returns:
            str: If ``return_str_on`` is True, a plain-text summary of each
            agent's name and response. Otherwise, the run metadata serialized
            as an indented JSON string.
        """
        logger.info(
            f"Running concurrent workflow with {len(self.agents)} agents."
        )
        self.output_schema = self._run_coroutine_sync(
            self._execute_agents_concurrently(task)
        )

        if self.auto_save:
            # Passing an unset (None) or empty WORKSPACE_DIR through would hand
            # an invalid folder to create_file_in_folder, so fall back to the
            # current working directory.
            workspace_dir = os.getenv("WORKSPACE_DIR") or os.getcwd()
            logger.info(
                f"Saving metadata to {self.metadata_output_path}"
            )
            create_file_in_folder(
                workspace_dir,
                self.metadata_output_path,
                self.output_schema.model_dump_json(indent=4),
            )

        if self.return_str_on:
            return self.transform_metadata_schema_to_str(
                self.output_schema
            )

        # Both branches return a string; this one is the JSON-serialized metadata.
        return self.output_schema.model_dump_json(indent=4)


# if __name__ == "__main__":
#     # Assuming you've already initialized some agents outside of this class
#     model = OpenAIChat(
#         api_key=os.getenv("OPENAI_API_KEY"),
#         model_name="gpt-4o-mini",
#         temperature=0.1,
#     )
#     agents = [
#         Agent(
#             agent_name=f"Financial-Analysis-Agent-{i}",
#             system_prompt=FINANCIAL_AGENT_SYS_PROMPT,
#             llm=model,
#             max_loops=1,
#             autosave=True,
#             dashboard=False,
#             verbose=True,
#             dynamic_temperature_enabled=True,
#             saved_state_path=f"finance_agent_{i}.json",
#             user_name="swarms_corp",
#             retry_attempts=1,
#             context_length=200000,
#             return_step_meta=False,
#         )
#         for i in range(3)  # Adjust number of agents as needed
#     ]

#     # Initialize the workflow with the list of agents
#     workflow = ConcurrentWorkflow(
#         agents=agents,
#         metadata_output_path="agent_metadata_4.json",
#         return_str_on=True,
#     )

#     # Define the task for all agents
#     task = "How can I establish a ROTH IRA to buy stocks and get a tax break? What are the criteria?"

#     # Run the workflow and save metadata
#     metadata = workflow.run(task)
#     print(metadata)
